import plotly.offline as pyo
from plotly.graph_objs import *
import chart_studio.plotly as py
import pandas as pd
from pandas import DataFrame
from plotly import tools
# Set up plotly's offline renderer for this notebook. connected=False is
# plotly's own default, spelled out here: plotly.js is embedded in the
# notebook rather than fetched from a CDN.
pyo.init_notebook_mode(connected=False)
# Read the iris measurements, using the CSV's first column as the row index,
# then preview the first five rows.
iris = pd.read_csv(filepath_or_buffer=r"../Data/irisDataset.csv", index_col=0)
iris.head(n=5)
| Sepal length | Sepal width | Petal length | Petal width | Species | |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | I. setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | I. setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | I. setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | I. setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | I. setosa |
def scatterplotMatrix(df, scatterColumns, categoricalColumn, colours, title,
                      padding=0.1, cellSize=200):
    """
    Create, display and return a scatterplot matrix of ``scatterColumns``.

    The cell in grid row ``j + 1`` and grid column ``i + 1`` plots
    ``scatterColumns[i]`` (x axis) against ``scatterColumns[j]`` (y axis),
    with one trace per category of ``categoricalColumn``. Cells where a
    variable would be plotted against itself are left empty.

    Parameters
    ----------
    df : pandas.DataFrame
        The DataFrame which contains the data.
    scatterColumns : list of str
        Columns of ``df`` to plot against each other.
    categoricalColumn : str
        Column whose distinct values split the points into coloured groups.
    colours : list of str
        At least one colour per distinct value of ``categoricalColumn``,
        matched in order of first appearance in ``df``.
    title : str
        Title of the chart.
    padding : float, default 0.1
        Fraction of the overall data span added below the minimum and above
        the maximum when setting the common axis range.
    cellSize : int, default 200
        Pixels of height and width allotted to each grid row / column.

    Returns
    -------
    plotly Figure
        The figure, which has also been displayed via ``pyo.iplot``.

    Raises
    ------
    ValueError
        If fewer colours than categories are supplied.
    """
    # plotly.tools.make_subplots is deprecated and just forwards to this.
    from plotly.subplots import make_subplots

    categories = list(df[categoricalColumn].unique())
    if len(colours) < len(categories):
        # zip() below would otherwise silently drop the surplus categories.
        raise ValueError(
            "scatterplotMatrix needs {} colours for column '{}', got {}".format(
                len(categories), categoricalColumn, len(colours)))
    colourLookup = dict(zip(categories, colours))
    # Filter the rows once per category rather than once per grid cell.
    subsets = {category: df[df[categoricalColumn] == category]
               for category in colourLookup}

    n = len(scatterColumns)
    fig = make_subplots(rows=n,
                        cols=n,
                        print_grid=True,
                        shared_xaxes=True,
                        shared_yaxes=True)

    # One common range for every axis so all panels share the same scale.
    # DataFrame.min()/max() skip NaN, unlike the builtin min()/max().
    values = df[scatterColumns]
    lowest = values.min().min()
    highest = values.max().max()
    margin = (highest - lowest) * padding
    axisRange = [lowest - margin, highest + margin]

    for i, xColumn in enumerate(scatterColumns):
        for j, yColumn in enumerate(scatterColumns):
            if xColumn == yColumn:
                continue
            # List each category in the legend once, from the first
            # off-diagonal cell visited (grid row 2, column 1).
            showLegend = i == 0 and j == 1
            for category, colour in colourLookup.items():
                subset = subsets[category]
                fig.add_trace({'type': 'scatter',
                               'mode': 'markers',
                               'x': subset[xColumn],
                               'y': subset[yColumn],
                               'marker': {'color': colour,
                                          'size': 3},
                               'name': category,
                               'legendgroup': category,
                               'showlegend': showLegend},
                              row=j + 1,
                              col=i + 1)

    # Shared x axes show tick labels on the bottom row and shared y axes on
    # the first column, so each variable's title goes on that edge.
    for i, xColumn in enumerate(scatterColumns):
        fig.update_xaxes(title_text=xColumn, row=n, col=i + 1)
    for j, yColumn in enumerate(scatterColumns):
        fig.update_yaxes(title_text=yColumn, row=j + 1, col=1)
    # Apply the common range to every subplot axis, not just the first row.
    fig.update_xaxes(range=axisRange)
    fig.update_yaxes(range=axisRange)

    fig.update_layout(title=title,
                      height=n * cellSize,
                      width=n * cellSize)
    pyo.iplot(fig)
    return fig
# Full 4x4 matrix of every iris measurement, coloured by species.
irisScatter = scatterplotMatrix(
    df=iris,
    scatterColumns=['Sepal length', 'Sepal width', 'Petal length', 'Petal width'],
    categoricalColumn='Species',
    colours=['purple', 'orange', 'green'],
    title='Scatterplot matrix of Iris dataset',
)
/Users/josh/opt/anaconda3/lib/python3.9/site-packages/plotly/tools.py:460: DeprecationWarning: plotly.tools.make_subplots is deprecated, please use plotly.subplots.make_subplots instead
This is the format of your plot grid: [ (1,1) x,y ] [ (1,2) x2,y2 ] [ (1,3) x3,y3 ] [ (1,4) x4,y4 ] [ (2,1) x5,y5 ] [ (2,2) x6,y6 ] [ (2,3) x7,y7 ] [ (2,4) x8,y8 ] [ (3,1) x9,y9 ] [ (3,2) x10,y10 ] [ (3,3) x11,y11 ] [ (3,4) x12,y12 ] [ (4,1) x13,y13 ] [ (4,2) x14,y14 ] [ (4,3) x15,y15 ] [ (4,4) x16,y16 ]
# 2x2 matrix restricted to the petal measurements, coloured by species.
irisScatter = scatterplotMatrix(
    df=iris,
    scatterColumns=['Petal length', 'Petal width'],
    categoricalColumn='Species',
    colours=['purple', 'orange', 'green'],
    title='Scatterplot matrix of Iris dataset',
)
This is the format of your plot grid: [ (1,1) x,y ] [ (1,2) x2,y2 ] [ (2,1) x3,y3 ] [ (2,2) x4,y4 ]
# Give every row the same dummy category so the matrix is drawn in a single
# colour, ignoring species.
iris['noCat'] = 'Iris'
irisScatter = scatterplotMatrix(
    df=iris,
    scatterColumns=['Petal length', 'Petal width'],
    categoricalColumn='noCat',
    colours=['purple'],
    title='Scatterplot matrix of Iris dataset',
)
This is the format of your plot grid: [ (1,1) x,y ] [ (1,2) x2,y2 ] [ (2,1) x3,y3 ] [ (2,2) x4,y4 ]